#!/usr/bin/env python

import copy
import rospy
import threading
import numpy as np
from interbotix_xs_msgs.msg import ArmJoy
from interbotix_common_modules import angle_manipulation as ang
from interbotix_xs_modules.arm import InterbotixManipulatorXS

### @brief Joystick teleoperation controller for an Interbotix X-Series arm
### @details Keeps the most recent ArmJoy message received on "commands/joy_processed"
###          and, every time controller() is called, converts it into pose, waist and
###          end-effector commands for the arm
class ArmRobot(object):
    ### @brief Reads the robot model and name from ROS, connects to the arm and subscribes to the joystick topic
    def __init__(self):
        # Latest joystick message; written by joy_control_cb and read by
        # controller, in both cases while holding joy_mutex
        self.joy_mutex = threading.Lock()
        self.joy_msg = ArmJoy()

        # Increments applied by controller() for each active command
        self.translate_step = 0.01   # added to the x/y/z translation entries of T_yb
        self.rotate_step = 0.04      # added to the roll/pitch Euler angles of T_yb
        self.waist_step = 0.06       # added to the commanded waist joint position

        # Gripper pressure (0 - 1) and the amount it changes per PWM command
        self.current_gripper_pressure = 0.5
        self.gripper_pressure_step = 0.125

        # The arm is treated as torqued on until a TORQUE_OFF command arrives
        self.current_torque_status = True

        # Control loop frequency [Hz]; loop_rates remembers the rate last used in each speed mode
        self.current_loop_rate = 25
        self.loop_rates = {"course": 25, "fine": 25}
        self.rate = rospy.Rate(self.current_loop_rate)

        # The robot name is this node's namespace with the slashes stripped
        model = rospy.get_param("~robot_model")
        name = rospy.get_namespace().strip("/")
        self.armbot = InterbotixManipulatorXS(model, robot_name=name, moving_time=0.2, accel_time=0.1, init_node=False)

        # Cache the joint count and the limits of the waist joint
        info = self.armbot.arm.group_info
        self.num_joints = info.num_joints
        self.waist_index = info.joint_names.index("waist")
        self.waist_ll = info.joint_lower_limits[self.waist_index]
        self.waist_ul = info.joint_upper_limits[self.waist_index]

        # T_sy holds the yaw-only rotation taken from the end-effector pose command;
        # T_yb is that pose expressed relative to T_sy
        self.T_sy = np.eye(4)
        self.T_yb = np.eye(4)
        self.update_T_yb()

        rospy.Subscriber("commands/joy_processed", ArmJoy, self.joy_control_cb)

    ### @brief Changes the frequency at which the main control loop runs
    ### @param loop_rate - desired loop frequency [Hz]
    def update_speed(self, loop_rate):
        self.current_loop_rate = loop_rate
        # A new Rate object is created because the loop sleeps on self.rate
        self.rate = rospy.Rate(loop_rate)
        rospy.loginfo("Current loop rate is %d Hz." % self.current_loop_rate)

    ### @brief Recomputes T_sy and T_yb from the arm's end-effector pose command
    def update_T_yb(self):
        T_sb = self.armbot.arm.get_ee_pose_command()
        # Only the yaw (third Euler angle) of the pose goes into T_sy
        yaw = ang.rotationMatrixToEulerAngles(T_sb[:3, :3])[2]
        self.T_sy[:2, :2] = ang.yawToRotationMatrix(yaw)
        # T_yb = inverse(T_sy) * T_sb
        T_ys = ang.transInv(self.T_sy)
        self.T_yb = np.dot(T_ys, T_sb)

    ### @brief Stores and applies a new gripper pressure
    ### @param gripper_pressure - desired gripper pressure from 0 - 1
    def update_gripper_pressure(self, gripper_pressure):
        self.current_gripper_pressure = gripper_pressure
        self.armbot.gripper.set_pressure(gripper_pressure)
        rospy.loginfo("Gripper pressure is at %.2f%%." % (self.current_gripper_pressure * 100.0))

    ### @brief ROS Callback function that receives ArmJoy messages
    ### @param msg - ArmJoy ROS message
    ### @details Saves a copy of the message for controller() and immediately
    ###          handles the speed, gripper and torque commands it carries
    def joy_control_cb(self, msg):
        snapshot = copy.deepcopy(msg)
        with self.joy_mutex:
            self.joy_msg = snapshot

        self._handle_rate_cmds(msg)
        self._handle_gripper_cmds(msg)
        self._handle_torque_cmd(msg)

    ### @brief Applies the speed_cmd and speed_toggle_cmd fields of a message
    ### @param msg - ArmJoy ROS message
    def _handle_rate_cmds(self, msg):
        # Step the loop rate by 1 Hz, staying within 10 - 40 Hz
        rate_now = self.current_loop_rate
        if msg.speed_cmd == ArmJoy.SPEED_INC and rate_now < 40:
            self.update_speed(rate_now + 1)
        elif msg.speed_cmd == ArmJoy.SPEED_DEC and rate_now > 10:
            self.update_speed(rate_now - 1)

        # On a mode switch, remember the rate of the mode being left and
        # restore the rate last used in the mode being entered
        if msg.speed_toggle_cmd == ArmJoy.SPEED_COURSE:
            leaving, entering, note = "fine", "course", "Switched to Course Control"
        elif msg.speed_toggle_cmd == ArmJoy.SPEED_FINE:
            leaving, entering, note = "course", "fine", "Switched to Fine Control"
        else:
            return
        self.loop_rates[leaving] = self.current_loop_rate
        rospy.loginfo(note)
        self.update_speed(self.loop_rates[entering])

    ### @brief Applies the gripper_cmd and gripper_pwm_cmd fields of a message
    ### @param msg - ArmJoy ROS message
    def _handle_gripper_cmds(self, msg):
        gripper = self.armbot.gripper
        if msg.gripper_cmd == ArmJoy.GRIPPER_OPEN:
            gripper.open(delay=0)
        elif msg.gripper_cmd == ArmJoy.GRIPPER_CLOSE:
            gripper.close(delay=0)

        # Pressure is only raised while below 1 and only lowered while above 0
        pressure = self.current_gripper_pressure
        if msg.gripper_pwm_cmd == ArmJoy.GRIPPER_PWM_INC and pressure < 1:
            self.update_gripper_pressure(pressure + self.gripper_pressure_step)
        elif msg.gripper_pwm_cmd == ArmJoy.GRIPPER_PWM_DEC and pressure > 0:
            self.update_gripper_pressure(pressure - self.gripper_pressure_step)

    ### @brief Applies the torque_cmd field of a message
    ### @param msg - ArmJoy ROS message
    def _handle_torque_cmd(self, msg):
        if msg.torque_cmd == ArmJoy.TORQUE_ON:
            self.armbot.dxl.robot_torque_enable("group", "arm", True)
            # Re-read the joint positions and rebuild the transforms before
            # controller() is allowed to run again
            self.armbot.arm.capture_joint_positions()
            self.update_T_yb()
            self.current_torque_status = True
        elif msg.torque_cmd == ArmJoy.TORQUE_OFF:
            self.armbot.dxl.robot_torque_enable("group", "arm", False)
            self.current_torque_status = False

    ### @brief Main control loop to manipulate the arm
    ### @details Does nothing while the arm is torqued off; otherwise processes the
    ###          pose, waist and end-effector commands of the latest ArmJoy message
    def controller(self):
        if (self.current_torque_status == False):
            return

        # Work on a private copy of the latest joystick message
        with self.joy_mutex:
            cmd = copy.deepcopy(self.joy_msg)

        if cmd.pose_cmd != 0:
            self._run_pose_cmd(cmd.pose_cmd)

        if cmd.waist_cmd != 0:
            self._run_waist_cmd(cmd.waist_cmd)

        # ee_y_cmd only counts for arms with at least six joints
        moved = cmd.ee_x_cmd + cmd.ee_z_cmd
        if self.num_joints >= 6:
            moved += cmd.ee_y_cmd
        turned = cmd.ee_roll_cmd + cmd.ee_pitch_cmd

        # No end-effector command is active
        if moved + turned == 0:
            return

        # Edit a copy so self.T_yb stays untouched if the new pose is rejected
        T_yb = np.array(self.T_yb)
        if moved:
            self._shift_position(cmd, T_yb)
        if turned != 0:
            self._shift_orientation(cmd, T_yb)

        # Get desired transformation matrix of the end-effector w.r.t. the base frame
        arm = self.armbot.arm
        T_sd = np.dot(self.T_sy, T_yb)
        _, success = arm.set_ee_pose_matrix(T_sd, arm.get_joint_commands(), True, 0.2, 0.1, False)
        if success:
            self.T_yb = np.array(T_yb)

    ### @brief Sends the arm to its home or sleep pose and refreshes the stored transforms
    ### @param pose_cmd - non-zero pose_cmd value of an ArmJoy message
    def _run_pose_cmd(self, pose_cmd):
        arm = self.armbot.arm
        if pose_cmd == ArmJoy.HOME_POSE:
            arm.go_to_home_pose(1.5, 0.75)
        elif pose_cmd == ArmJoy.SLEEP_POSE:
            arm.go_to_sleep_pose(1.5, 0.75)
        self.update_T_yb()
        # Restore the trajectory times used by the other commands
        arm.set_trajectory_time(0.2, 0.1)

    ### @brief Steps the waist joint by waist_step and refreshes the stored transforms
    ### @param waist_cmd - non-zero waist_cmd value of an ArmJoy message
    def _run_waist_cmd(self, waist_cmd):
        arm = self.armbot.arm
        current = arm.get_single_joint_command("waist")

        # Pair the stepped target with the limit that lies in the same direction
        move = None
        if waist_cmd == ArmJoy.WAIST_CCW:
            move = (current + self.waist_step, self.waist_ul)
        elif waist_cmd == ArmJoy.WAIST_CW:
            move = (current - self.waist_step, self.waist_ll)

        if move is not None:
            target, limit = move
            success = arm.set_single_joint_position("waist", target, 0.2, 0.1, False)
            # If the step was refused and the joint is not already at the
            # limit, command the limit itself
            if (success == False and current != limit):
                arm.set_single_joint_position("waist", limit, 0.2, 0.1, False)

        self.update_T_yb()

    ### @brief Applies the ee_x/ee_y/ee_z commands to the translation part of T_yb in place
    ### @param cmd - ArmJoy ROS message
    ### @param T_yb - 4x4 transform to modify
    def _shift_position(self, cmd, T_yb):
        step = self.translate_step

        if cmd.ee_x_cmd == ArmJoy.EE_X_INC:
            T_yb[0, 3] += step
        elif cmd.ee_x_cmd == ArmJoy.EE_X_DEC:
            T_yb[0, 3] -= step

        # y steps need at least six joints and an x entry greater than 0.3
        # (checked after the x command above has been applied)
        if self.num_joints >= 6 and T_yb[0, 3] > 0.3:
            if cmd.ee_y_cmd == ArmJoy.EE_Y_INC:
                T_yb[1, 3] += step
            elif cmd.ee_y_cmd == ArmJoy.EE_Y_DEC:
                T_yb[1, 3] -= step

        if cmd.ee_z_cmd == ArmJoy.EE_Z_INC:
            T_yb[2, 3] += step
        elif cmd.ee_z_cmd == ArmJoy.EE_Z_DEC:
            T_yb[2, 3] -= step

    ### @brief Applies the ee_roll/ee_pitch commands to the rotation part of T_yb in place
    ### @param cmd - ArmJoy ROS message
    ### @param T_yb - 4x4 transform to modify
    def _shift_orientation(self, cmd, T_yb):
        step = self.rotate_step
        rpy = ang.rotationMatrixToEulerAngles(T_yb[:3, :3])

        if cmd.ee_roll_cmd == ArmJoy.EE_ROLL_CCW:
            rpy[0] += step
        elif cmd.ee_roll_cmd == ArmJoy.EE_ROLL_CW:
            rpy[0] -= step

        if cmd.ee_pitch_cmd == ArmJoy.EE_PITCH_DOWN:
            rpy[1] += step
        elif cmd.ee_pitch_cmd == ArmJoy.EE_PITCH_UP:
            rpy[1] -= step

        T_yb[:3, :3] = ang.eulerAnglesToRotationMatrix(rpy)

### @brief Starts the 'xsarm_robot' node and runs the arm controller until ROS shuts down
def main():
    rospy.init_node('xsarm_robot')
    robot = ArmRobot()
    while True:
        if rospy.is_shutdown():
            break
        robot.controller()
        # robot.rate is looked up every cycle because update_speed() replaces it
        robot.rate.sleep()

# Run the node only when this file is executed as a script
if __name__ == "__main__":
    main()
